package com.example.security;

import com.example.util.EnhancedJwtTokenUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import jakarta.servlet.http.HttpServletRequest;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.HexFormat;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.TimeUnit;

/**
 * 会话劫持防护组件
 */
@Slf4j
@Component
@RequiredArgsConstructor
public class SessionHijackingProtection {

    private final RedisTemplate<String, Object> redisTemplate;
    private final EnhancedJwtTokenUtil jwtTokenUtil;

    /**
     * 验证会话安全性
     */
    public boolean validateSession(String token, HttpServletRequest request) {
        try {
            if (!StringUtils.hasText(token)) {
                return false;
            }

            // 获取用户信息
            String username = jwtTokenUtil.getUsernameFromToken(token);
            if (username == null) {
                return false;
            }

            // 生成设备指纹
            String deviceFingerprint = generateDeviceFingerprint(request);
            
            // 检查设备指纹是否匹配
            if (!validateDeviceFingerprint(username, deviceFingerprint)) {
                log.warn("设备指纹不匹配，可能存在会话劫持: user={}, ip={}", 
                        username, getClientIp(request));
                return false;
            }

            // 检查IP地址变化
            if (!validateIpAddress(username, getClientIp(request))) {
                log.warn("IP地址异常变化，可能存在会话劫持: user={}, ip={}", 
                        username, getClientIp(request));
                return false;
            }

            // 检查User-Agent变化
            if (!validateUserAgent(username, request.getHeader("User-Agent"))) {
                log.warn("User-Agent异常变化，可能存在会话劫持: user={}", username);
                return false;
            }

            // 更新会话信息
            updateSessionInfo(username, deviceFingerprint, getClientIp(request), 
                            request.getHeader("User-Agent"));

            return true;

        } catch (Exception e) {
            log.error("验证会话安全性异常: {}", e.getMessage(), e);
            return false;
        }
    }

    /**
     * 生成设备指纹
     */
    private String generateDeviceFingerprint(HttpServletRequest request) {
        StringBuilder fingerprint = new StringBuilder();
        
        // User-Agent
        String userAgent = request.getHeader("User-Agent");
        if (userAgent != null) {
            fingerprint.append(userAgent);
        }
        
        // Accept-Language
        String acceptLanguage = request.getHeader("Accept-Language");
        if (acceptLanguage != null) {
            fingerprint.append(acceptLanguage);
        }
        
        // Accept-Encoding
        String acceptEncoding = request.getHeader("Accept-Encoding");
        if (acceptEncoding != null) {
            fingerprint.append(acceptEncoding);
        }
        
        // Accept
        String accept = request.getHeader("Accept");
        if (accept != null) {
            fingerprint.append(accept);
        }

        // 生成MD5哈希
        return generateMD5Hash(fingerprint.toString());
    }

    /**
     * 验证设备指纹
     */
    private boolean validateDeviceFingerprint(String username, String currentFingerprint) {
        try {
            String key = "session_fingerprint:" + username;
            String storedFingerprint = (String) redisTemplate.opsForValue().get(key);
            
            if (storedFingerprint == null) {
                // 首次登录，存储指纹
                redisTemplate.opsForValue().set(key, currentFingerprint, 24, TimeUnit.HOURS);
                return true;
            }
            
            return storedFingerprint.equals(currentFingerprint);
            
        } catch (Exception e) {
            log.error("验证设备指纹异常: {}", e.getMessage());
            return true; // 异常情况下允许通过，避免影响正常用户
        }
    }

    /**
     * 验证IP地址
     */
    private boolean validateIpAddress(String username, String currentIp) {
        try {
            String key = "session_ip:" + username;
            String storedIp = (String) redisTemplate.opsForValue().get(key);
            
            if (storedIp == null) {
                // 首次登录，存储IP
                redisTemplate.opsForValue().set(key, currentIp, 24, TimeUnit.HOURS);
                return true;
            }
            
            // 允许同一网段的IP变化（考虑NAT等情况）
            return isSameNetwork(storedIp, currentIp) || storedIp.equals(currentIp);
            
        } catch (Exception e) {
            log.error("验证IP地址异常: {}", e.getMessage());
            return true;
        }
    }

    /**
     * 验证User-Agent
     */
    private boolean validateUserAgent(String username, String currentUserAgent) {
        try {
            String key = "session_ua:" + username;
            String storedUserAgent = (String) redisTemplate.opsForValue().get(key);
            
            if (storedUserAgent == null) {
                // 首次登录，存储User-Agent
                redisTemplate.opsForValue().set(key, currentUserAgent, 24, TimeUnit.HOURS);
                return true;
            }
            
            // 允许轻微的User-Agent变化（版本更新等）
            return isSimilarUserAgent(storedUserAgent, currentUserAgent);
            
        } catch (Exception e) {
            log.error("验证User-Agent异常: {}", e.getMessage());
            return true;
        }
    }

    /**
     * 更新会话信息
     */
    private void updateSessionInfo(String username, String fingerprint, String ip, String userAgent) {
        try {
            Map<String, Object> sessionInfo = new HashMap<>();
            sessionInfo.put("fingerprint", fingerprint);
            sessionInfo.put("ip", ip);
            sessionInfo.put("userAgent", userAgent);
            sessionInfo.put("lastAccess", LocalDateTime.now().toString());
            
            String key = "session_info:" + username;
            redisTemplate.opsForValue().set(key, sessionInfo, 24, TimeUnit.HOURS);
            
        } catch (Exception e) {
            log.error("更新会话信息异常: {}", e.getMessage());
        }
    }

    /**
     * 检查是否为同一网段
     */
    private boolean isSameNetwork(String ip1, String ip2) {
        try {
            if (ip1 == null || ip2 == null) {
                return false;
            }
            
            // 简单的同网段检查（前3段相同）
            String[] parts1 = ip1.split("\\.");
            String[] parts2 = ip2.split("\\.");
            
            if (parts1.length != 4 || parts2.length != 4) {
                return false;
            }
            
            return parts1[0].equals(parts2[0]) && 
                   parts1[1].equals(parts2[1]) && 
                   parts1[2].equals(parts2[2]);
                   
        } catch (Exception e) {
            return false;
        }
    }

    /**
     * 检查User-Agent是否相似
     */
    private boolean isSimilarUserAgent(String ua1, String ua2) {
        if (ua1 == null || ua2 == null) {
            return ua1 == ua2;
        }
        
        // 提取主要特征进行比较
        String browser1 = extractBrowser(ua1);
        String browser2 = extractBrowser(ua2);
        
        String os1 = extractOS(ua1);
        String os2 = extractOS(ua2);
        
        return browser1.equals(browser2) && os1.equals(os2);
    }

    /**
     * 提取浏览器信息
     */
    private String extractBrowser(String userAgent) {
        if (userAgent == null) return "";
        
        String ua = userAgent.toLowerCase();
        if (ua.contains("chrome")) return "chrome";
        if (ua.contains("firefox")) return "firefox";
        if (ua.contains("safari")) return "safari";
        if (ua.contains("edge")) return "edge";
        if (ua.contains("opera")) return "opera";
        
        return "unknown";
    }

    /**
     * 提取操作系统信息
     */
    private String extractOS(String userAgent) {
        if (userAgent == null) return "";
        
        String ua = userAgent.toLowerCase();
        if (ua.contains("windows")) return "windows";
        if (ua.contains("mac")) return "mac";
        if (ua.contains("linux")) return "linux";
        if (ua.contains("android")) return "android";
        if (ua.contains("ios")) return "ios";
        
        return "unknown";
    }

    /**
     * 生成MD5哈希
     */
    private String generateMD5Hash(String input) {
        try {
            MessageDigest md = MessageDigest.getInstance("MD5");
            byte[] hashBytes = md.digest(input.getBytes());
            
            StringBuilder sb = new StringBuilder();
            for (byte b : hashBytes) {
                sb.append(String.format("%02x", b));
            }
            
            return sb.toString();
            
        } catch (NoSuchAlgorithmException e) {
            log.error("生成MD5哈希异常: {}", e.getMessage());
            return input.hashCode() + "";
        }
    }

    /**
     * 获取客户端IP
     */
    private String getClientIp(HttpServletRequest request) {
        String ip = request.getHeader("X-Forwarded-For");
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip)) {
            ip = request.getRemoteAddr();
        }
        
        if (ip != null && ip.contains(",")) {
            ip = ip.split(",")[0].trim();
        }
        
        return ip != null ? ip : "unknown";
    }

    /**
     * 强制注销用户会话
     */
    public void forceLogout(String username, String reason) {
        try {
            // 撤销所有token
            jwtTokenUtil.revokeAllUserTokens(username);
            
            // 清除会话信息
            redisTemplate.delete("session_fingerprint:" + username);
            redisTemplate.delete("session_ip:" + username);
            redisTemplate.delete("session_ua:" + username);
            redisTemplate.delete("session_info:" + username);
            
            log.warn("强制注销用户会话: user={}, reason={}", username, reason);
            
        } catch (Exception e) {
            log.error("强制注销用户会话异常: {}", e.getMessage(), e);
        }
    }

    /**
     * 检查会话是否存在异常
     */
    public boolean hasSessionAnomaly(String username) {
        try {
            String key = "session_anomaly:" + username;
            return redisTemplate.hasKey(key);
        } catch (Exception e) {
            log.error("检查会话异常状态失败: {}", e.getMessage());
            return false;
        }
    }

    /**
     * 标记会话异常
     */
    public void markSessionAnomaly(String username, String reason) {
        try {
            String key = "session_anomaly:" + username;
            Map<String, Object> anomaly = new HashMap<>();
            anomaly.put("reason", reason);
            anomaly.put("timestamp", LocalDateTime.now().toString());
            
            redisTemplate.opsForValue().set(key, anomaly, 1, TimeUnit.HOURS);
            
            log.warn("标记会话异常: user={}, reason={}", username, reason);
            
        } catch (Exception e) {
            log.error("标记会话异常失败: {}", e.getMessage(), e);
        }
    }
}
